package com.jstarcraft.ai.jsat.classifiers.bayesian;

import com.jstarcraft.ai.jsat.classifiers.BaseUpdateableClassifier;
import com.jstarcraft.ai.jsat.classifiers.CategoricalData;
import com.jstarcraft.ai.jsat.classifiers.CategoricalResults;
import com.jstarcraft.ai.jsat.classifiers.ClassificationDataSet;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.exceptions.FailedToFitException;
import com.jstarcraft.ai.jsat.utils.concurrent.ParallelUtils;

/**
 * Averaged One-Dependence Estimators (AODE) is an extension of Naive Bayes that
 * attempts to be more accurate by reducing the independence assumption. For
 * <i>n</i> data points with <i>d</i> categorical features, <i>d</i> {@link ODE}
 * classifier are created, each with a dependence on a different attribute. The
 * results of these classifiers is averaged to produce a final result. The
 * construction time is <i>O(n d<sup>2</sup>)</i>. Because of this extra
 * dependence requirement, the implementation only allows for categorical
 * features. <br>
 * <br>
 * See: Webb, G., &amp; Boughton, J. (2005). <i>Not so naive bayes: Aggregating
 * one-dependence estimators</i>. Machine Learning, 1–24. Retrieved from
 * <a href="http://www.springerlink.com/index/U8W306673M1P866K.pdf">here</a>
 * 
 * @author Edward Raff
 */
public class AODE extends BaseUpdateableClassifier {

    private static final long serialVersionUID = 8386506277969540732L;
    protected CategoricalData predicting;
    protected ODE[] odes;

    /**
     * The minimum value to use a probability
     */
    private double m = 20;

    /**
     * Creates a new AODE classifier.
     */
    public AODE() {
    }

    /**
     * Creates a copy of an AODE classifier
     * 
     * @param toClone the classifier to clone
     */
    protected AODE(AODE toClone) {
        if (toClone.odes != null) {
            this.odes = new ODE[toClone.odes.length];
            for (int i = 0; i < this.odes.length; i++)
                this.odes[i] = toClone.odes[i].clone();
            this.predicting = toClone.predicting.clone();
        }
        this.m = toClone.m;
    }

    @Override
    public AODE clone() {
        return new AODE(this);
    }

    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        if (categoricalAttributes.length < 1)
            throw new FailedToFitException("At least 2 categorical varaibles are needed for AODE");
        this.predicting = predicting;
        odes = new ODE[categoricalAttributes.length];

        for (int i = 0; i < odes.length; i++) {
            odes[i] = new ODE(i);
            odes[i].setUp(categoricalAttributes, numericAttributes, predicting);
        }
    }

    @Override
    public void train(final ClassificationDataSet dataSet, boolean parallel) {
        setUp(dataSet.getCategories(), dataSet.getNumNumericalVars(), dataSet.getPredicting());

        ParallelUtils.range(odes.length, parallel).forEach(z -> {
            ODE ode = odes[z];
            for (int i = 0; i < dataSet.size(); i++)
                ode.update(dataSet.getDataPoint(i), dataSet.getWeight(i), dataSet.getDataPointCategory(i));
        });
    }

    @Override
    public void update(DataPoint dataPoint, double weight, int targetClass) {
        for (ODE ode : odes)
            ode.update(dataPoint, targetClass);
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(predicting.getNumOfCategories());

        int[] catVals = data.getCategoricalValues();
        for (int c = 0; c < cr.size(); c++) {
            double prob = 0.0;
            for (ODE ode : odes)
                if (ode.priors[c][catVals[ode.dependent]] < m)
                    continue;
                else
                    prob += Math.exp(ode.getLogPrb(catVals, c));
            cr.setProb(c, prob);
        }
        cr.normalize();

        return cr;
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Sets the minimum prior observation value needed for an attribute combination
     * to have enough support to be included in the final estimate.
     * 
     * @param m the minimum needed score
     */
    public void setM(double m) {
        if (m < 0 || Double.isInfinite(m) || Double.isNaN(m))
            throw new ArithmeticException("The minimum count must be a non negative number");
        this.m = m;
    }

    /**
     * Returns the minimum needed score
     * 
     * @return the minimum needed score
     */
    public double getM() {
        return m;
    }

}
